
Feat/more losses #845

Merged: 15 commits merged into master from feat/more-losses on Mar 19, 2022

Conversation

@hrzn (Contributor) commented on Mar 14, 2022:

Add two new PyTorch loss functions (SmapeLoss and MapeLoss), which provide alternative training criteria and could, for instance, be used to replicate some of the M3/M4 competition results.
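For orientation, here is a minimal sketch of what such a criterion can look like as a torch.nn.Module. The class name matches the PR, but the body is illustrative rather than the PR's exact code, and the `_divide_no_nan` helper (referenced in the review below) is reimplemented here as an assumption:

```python
import torch
import torch.nn as nn


def _divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Element-wise a / b, mapping the NaN/inf produced by zero
    # denominators to 0 (a sketch of the helper named in the review below).
    return torch.nan_to_num(a / b, nan=0.0, posinf=0.0, neginf=0.0)


class SmapeLoss(nn.Module):
    # Differentiable sMAPE: mean of 2 * |y - y_hat| / (|y| + |y_hat|).
    def forward(self, inpt, tgt):
        num = torch.abs(tgt - inpt)
        denom = torch.abs(tgt) + torch.abs(inpt)
        return 2.0 * torch.mean(_divide_no_nan(num, denom))
```

A criterion like this can then typically be passed to a PyTorch-based Darts model via its loss_fn constructor argument.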

@codecov-commenter commented on Mar 15, 2022:

Codecov Report

Merging #845 (f9dc3f1) into master (a7abedf) will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #845      +/-   ##
==========================================
+ Coverage   91.40%   91.43%   +0.03%     
==========================================
  Files          70       71       +1     
  Lines        7106     7135      +29     
==========================================
+ Hits         6495     6524      +29     
  Misses        611      611              
Impacted Files           Coverage Δ
darts/utils/losses.py    100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

darts/utils/losses.py:

        super().__init__()

    def forward(self, inpt, tgt):
        return torch.mean(torch.abs(inpt - tgt))
A Collaborator commented:

Isn't this just the MAE? Or is this to overcome some of the issues with MAPE?

Suggested change:
-        return torch.mean(torch.abs(inpt - tgt))
+        return torch.mean(torch.abs(_divide_no_nan(inpt - tgt, inpt)))

A Contributor commented:
+1

@hrzn (Author) replied on Mar 19, 2022:
You're right. Initially I ignored the denominator because it only impacts the magnitude of the gradients, and it was giving somewhat better results, but it's not quite correct.
I have changed it and also added MAE to the list now (unit test on its way) :)
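For illustration, the corrected MAPE-style criterion could look as follows. This is a sketch reusing the hypothetical `_divide_no_nan` helper from above; note that it normalizes by the target (the conventional MAPE denominator), whereas the reviewer's suggestion divides by `inpt`:

```python
class MapeLoss(nn.Module):
    # Differentiable MAPE: mean of |y_hat - y| / |y|, with divisions
    # by zero mapped to 0 by the _divide_no_nan sketch above.
    def forward(self, inpt, tgt):
        return torch.mean(torch.abs(_divide_no_nan(inpt - tgt, tgt)))
```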

air_s = scaler.fit_transform(air)
air_train, air_val = air_s[:-36], air_s[-36:]

def test_smape_loss(self):
A Collaborator commented:

How about we check the actual outputs of the losses instead of fitting models? Just thinking about execution time; the fitting takes a couple of seconds.

A Contributor replied:

Yes, I thought of that as well. Although actually using the loss functions for fitting might reveal some problems that we wouldn't notice otherwise.

@hrzn (Author) replied:

Yes, it's a tiny bit better to test the fitting to make sure the gradients flow where they should, so we can leave it like that for now.

@hrzn (Author) added:

On second thought, I think your idea is better, as long as we're also checking the loss gradients. I've changed the tests to do that now, thanks for the suggestion 👍
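A direct value-and-gradient test along those lines could look like the following sketch (hypothetical names; it reuses the SmapeLoss sketch above instead of fitting a full model):

```python
def test_loss_value_and_gradient():
    # Push a tiny differentiable "model" through the loss and verify
    # that the value is sane and finite gradients reach the parameters.
    W = torch.tensor([[0.1, -0.2, 0.3, -0.4]], requires_grad=True)
    y = torch.tensor([[0.0, 0.5, -0.5, 1.0]])
    y_hat = 2.0 * W  # trivial model so that W receives gradients

    loss = SmapeLoss()(y_hat, y)
    loss.backward()

    assert loss.item() >= 0.0             # sMAPE is non-negative
    assert W.grad is not None             # gradients flowed back to W
    assert torch.isfinite(W.grad).all()   # no NaN/inf from the division
```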

@brunnedu (Contributor) left a review:

Nice addition! 👍

darts/utils/losses.py (conversation resolved)

@dennisbader (Collaborator) left a review:
Very nice, thanks a lot!
After addressing the last suggestions, it can be merged.


def helper_test_loss(self, exp_loss_val, exp_w_grad, loss_fn):
    W = torch.tensor([[0.1, -0.2, 0.3, -0.4], [-0.8, 0.7, -0.6, 0.5]])
    W.requires_grad = True
A Collaborator commented:
Nice tests +1

    lval = loss_fn(y_hat, self.y)
    lval.backward()

    print(lval)
A Collaborator commented:
The print statement should be removed.

darts/utils/losses.py (outdated, resolved)
@hrzn merged commit eda8f94 into master on Mar 19, 2022.
@madtoinou deleted the feat/more-losses branch on July 5, 2023.